import copy
import torch
import numpy as np
import time
from flcore.clients.clientbase import Client
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import random
from torch.utils.data import DataLoader, TensorDataset
import math
from sklearn.preprocessing import label_binarize
from sklearn import metrics
import os


class clientFABLE(Client):
    """Federated client that classifies inputs through anchor-relative representations.

    A fixed set of ``anchors_num`` local training inputs ("anchors") is sampled
    once. Every input is embedded by ``self.model.base`` and L2-normalized, then
    multiplied against the L2-normalized anchor embeddings (optionally passed
    through ``self.linear_transform``). The resulting vector, one entry per
    anchor, is what ``self.model.head`` classifies.
    """

    def __init__(self, args, id, train_samples, test_samples, **kwargs):
        super().__init__(args, id, train_samples, test_samples, **kwargs)
        # Anchor inputs/labels are sampled lazily from the local training data
        # the first time they are needed (see _ensure_anchors / get_anchors).
        self.anchors = None
        self.anchors_label = None
        self.anchors_num = args.anchors_number
        self.linear_anchor = args.linear_anchor
        if self.linear_anchor:
            # Square, bias-free map applied to the anchor embeddings before projection.
            # NOTE(review): this module is not part of self.model; confirm that the
            # optimizer created by the base Client includes its parameters, otherwise
            # it stays at its random initialisation (and its grads are never zeroed).
            self.linear_transform = nn.Linear(
                in_features=self.model.head.in_features,
                out_features=self.model.head.in_features,
                bias=False,
            ).to(args.device)
            nn.init.normal_(self.linear_transform.weight, mean=0.0, std=0.01)

    # ------------------------------------------------------------------ helpers

    def _ensure_anchors(self, trainloader=None):
        """Sample anchors (and their labels) from the training data if not done yet."""
        if self.anchors is None:
            if trainloader is None:
                trainloader = self.load_train_data()
            self.anchors, self.anchors_label = self.get_anchors(trainloader)

    def _encode_anchors(self):
        """Return the L2-normalized ``self.model.base`` embedding of the anchors."""
        return F.normalize(self.model.base(self.anchors), p=2, dim=-1)

    def _relative_forward(self, x, anchors_representation):
        """Classify ``x`` from its projection onto the anchor embeddings.

        Args:
            x: input batch already on ``self.device``.
            anchors_representation: output of ``_encode_anchors``.

        Returns:
            The output of ``self.model.head`` applied to
            ``normalize(base(x)) @ anchors.T`` (one column per anchor). Without the
            linear transform these entries are cosine similarities.
        """
        absolute_representation = F.normalize(self.model.base(x), p=2, dim=-1)
        if self.linear_anchor:
            anchors_representation = self.linear_transform(anchors_representation)
        transformation_representation = torch.matmul(
            absolute_representation, anchors_representation.T
        )
        return self.model.head(transformation_representation)

    # --------------------------------------------------------------- training

    def train(self):
        """Run one round of local training and record its wall-clock cost."""
        trainloader = self.load_train_data()
        self._ensure_anchors(trainloader)
        self.model.train()

        start_time = time.time()

        max_local_epochs = self.local_epochs
        if self.train_slow:
            # randint's upper bound is exclusive and must exceed the lower bound;
            # clamp it so small epoch counts do not raise ValueError.
            max_local_epochs = np.random.randint(1, max(2, max_local_epochs // 2))

        # Anchor embeddings are computed once per round from the round-start base
        # weights and kept constant (no gradient) while the base is updated below.
        with torch.no_grad():
            anchors_representation = self._encode_anchors().detach()

        for _ in range(max_local_epochs):
            for x, y in trainloader:
                x = x[0].to(self.device) if isinstance(x, list) else x.to(self.device)
                y = y.to(self.device)

                output = self._relative_forward(x, anchors_representation)

                loss = self.loss(output, y)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

        if self.learning_rate_decay:
            self.learning_rate_scheduler.step()

        self.train_time_cost['num_rounds'] += 1
        self.train_time_cost['total_cost'] += time.time() - start_time

    def get_anchors(self, trainloader):
        """Sample ``self.anchors_num`` distinct training inputs to act as anchors.

        The whole loader is gathered on the CPU and only the sampled anchors are
        moved to ``self.device``, so the full dataset never occupies device memory.

        Returns:
            Tuple ``(anchors, anchors_labels)``, both on ``self.device``.

        Raises:
            ValueError: if the loader is empty or yields fewer than
                ``self.anchors_num`` samples.
        """
        images = []
        labels = []
        for x, y in trainloader:
            x = x[0] if isinstance(x, list) else x
            images.append(x.cpu())
            labels.append(y.cpu())

        if not images:
            raise ValueError("cannot sample anchors: the training loader is empty")
        images = torch.cat(images, dim=0)
        labels = torch.cat(labels, dim=0)

        num_samples = images.size(0)
        if num_samples < self.anchors_num:
            raise ValueError(
                f"cannot sample {self.anchors_num} anchors from only {num_samples} training samples"
            )

        anchors_index = torch.as_tensor(
            np.random.choice(num_samples, self.anchors_num, replace=False), dtype=torch.long
        )
        anchors = images[anchors_index].to(self.device)
        anchors_labels = labels[anchors_index].to(self.device)
        return anchors, anchors_labels

    def set_parameters(self, model):
        """Copy ``model``'s parameters into this client's ``self.model.base``."""
        for new_param, old_param in zip(model.parameters(), self.model.base.parameters()):
            old_param.data = new_param.data.clone()

    # -------------------------------------------------------------- evaluation

    def test_metrics(self, representation_type="absolute"):
        """Evaluate on the local test set.

        ``representation_type`` is accepted for interface compatibility but unused.

        Returns:
            ``(num_correct, num_samples, micro_auc)``.
        """
        testloaderfull = self.load_test_data()
        self.model.eval()

        test_acc = 0
        test_num = 0
        y_prob = []
        y_true = []

        with torch.no_grad():
            self._ensure_anchors()
            # In eval mode under no_grad the anchor embedding is batch-invariant,
            # so it is computed once rather than per batch.
            anchors_representation = self._encode_anchors()

            for x, y in testloaderfull:
                x = x[0].to(self.device) if isinstance(x, list) else x.to(self.device)
                y = y.to(self.device)

                output = self._relative_forward(x, anchors_representation)

                test_acc += (torch.sum(torch.argmax(output, dim=1) == y)).item()
                test_num += y.shape[0]

                y_prob.append(output.detach().cpu().numpy())
                nc = self.num_classes
                if self.num_classes == 2:
                    # label_binarize returns a single column for two classes;
                    # binarize over three and keep the first two columns instead.
                    nc += 1
                lb = label_binarize(y.detach().cpu().numpy(), classes=np.arange(nc))
                if self.num_classes == 2:
                    lb = lb[:, :2]
                y_true.append(lb)

        y_prob = np.concatenate(y_prob, axis=0)
        y_true = np.concatenate(y_true, axis=0)

        auc = metrics.roc_auc_score(y_true, y_prob, average='micro')

        return test_acc, test_num, auc

    def train_metrics(self, representation_type="absolute"):
        """Compute the summed training loss over the local training set.

        ``representation_type`` is accepted for interface compatibility but unused.

        Returns:
            ``(sum_of_per_sample_losses, num_samples)``.
        """
        trainloader = self.load_train_data()
        self.model.eval()

        train_num = 0
        losses = 0

        with torch.no_grad():
            self._ensure_anchors(trainloader)
            anchors_representation = self._encode_anchors()

            for x, y in trainloader:
                x = x[0].to(self.device) if isinstance(x, list) else x.to(self.device)
                y = y.to(self.device)

                output = self._relative_forward(x, anchors_representation)
                loss = self.loss(output, y)

                train_num += y.shape[0]
                losses += loss.item() * y.shape[0]

        return losses, train_num

